#!/usr/bin/env python3
"""
PoXiao大模型训练脚本
"""

import os
import sys

# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.args_parser import parse_args
from plugins.trainers.pretrain_trainer import PretrainTrainer

def main():
    # 解析配置
    config = parse_args()
    
    # 创建输出目录
    os.makedirs(config["output_dir"], exist_ok=True)
    
    # 根据插件类型创建训练器
    trainer_plugin = config.get("trainer_plugin", "pretrain")
    
    if trainer_plugin == "pretrain":
        trainer = PretrainTrainer(config)
    else:
        raise ValueError(f"Unknown trainer plugin: {trainer_plugin}")
        
    # 开始训练
    trainer.train()

if __name__ == "__main__":
    main()